import torch
from ai.models.music_genre_classifier import MusicGenreClassifier

def test_model_output_shape():
    """Check that the classifier emits one score per class for one input.

    Builds a ``MusicGenreClassifier`` with 10 classes, feeds it a single
    random ``(1, 1, 128, 128)`` tensor standing in for a mel spectrogram,
    and asserts that the output has shape ``(batch_size, num_classes)``.
    """
    # Named once so the constructor argument and the expected shape cannot
    # drift apart.
    num_classes = 10
    batch_size = 1

    model = MusicGenreClassifier(num_classes=num_classes)
    # Run in inference mode: with a single-sample batch, training-mode
    # layers (dropout, batch norm) make the forward pass stochastic or can
    # reject the batch outright.
    # NOTE(review): assumes MusicGenreClassifier is a torch.nn.Module -- confirm.
    model.eval()

    # Simulated mel-spectrogram input.
    # NOTE(review): presumably (batch, channels, mel bins, frames) -- confirm
    # against the model definition.
    input_tensor = torch.randn(batch_size, 1, 128, 128)

    # A shape check needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        output = model(input_tensor)

    expected_shape = (batch_size, num_classes)
    assert output.shape == expected_shape, (
        f"expected output shape {expected_shape}, got {tuple(output.shape)}"
    )